{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-46yeargkQ-4"
   },
   "source": [
    "##### Copyright 2024 Google LLC."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "gWSrDNlo86sD"
   },
   "outputs": [],
   "source": [
    "#@title Licensed under the Apache License, Version 2.0 (the \"License\")\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "# https://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "5BnvlHhhBsbO"
   },
   "source": [
    "<table align=\"left\">\n",
    "  <td style=\"text-align: center\">\n",
    "    <a href=\"https://art-analytics.appspot.com/r.html?uaid=G-FHXEFWTT4E&utm_source=aRT-gemini-2-spatial-reasoning&utm_medium=aRT-clicks&utm_campaign=gemini-2-spatial-reasoning&destination=gemini-2-spatial-reasoning&url=https%3A%2F%2Fcolab.research.google.com%2Fgithub%2FGoogleCloudPlatform%2Fapplied-ai-engineering-samples%2Fblob%2Fmain%2Fgenai-on-vertex-ai%2Fgemini%2Fprompting_recipes%2Fspatial_reasoning%2Fspatial_reasoning_app_for_gemini2.ipynb\">\n",
    "      <img width=\"32px\" src=\"https://www.gstatic.com/pantheon/images/bigquery/welcome_page/colab-logo.svg\" alt=\"Google Colaboratory logo\"><br> Open in Colab\n",
    "    </a>\n",
    "  </td>\n",
    "  <td style=\"text-align: center\">\n",
    "    <a href=\"https://art-analytics.appspot.com/r.html?uaid=G-FHXEFWTT4E&utm_source=aRT-gemini-2-spatial-reasoning&utm_medium=aRT-clicks&utm_campaign=gemini-2-spatial-reasoning&destination=gemini-2-spatial-reasoning&url=https%3A%2F%2Fconsole.cloud.google.com%2Fvertex-ai%2Fcolab%2Fimport%2Fhttps%3A%252F%252Fraw.githubusercontent.com%252FGoogleCloudPlatform%252Fapplied-ai-engineering-samples%252Fmain%252Fgenai-on-vertex-ai%252Fgemini%252Fprompting_recipes%252Fspatial_reasoning%252Fspatial_reasoning_app_for_gemini2.ipynb\">\n",
    "      <img width=\"32px\" src=\"https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN\" alt=\"Google Cloud Colab Enterprise logo\"><br> Open in Colab Enterprise\n",
    "    </a>\n",
    "  </td>\n",
    "  <td style=\"text-align: center\">\n",
    "    <a href=\"https://art-analytics.appspot.com/r.html?uaid=G-FHXEFWTT4E&utm_source=aRT-gemini-2-spatial-reasoning&utm_medium=aRT-clicks&utm_campaign=gemini-2-spatial-reasoning&destination=gemini-2-spatial-reasoning&url=https%3A%2F%2Fconsole.cloud.google.com%2Fvertex-ai%2Fworkbench%2Fdeploy-notebook%3Fdownload_url%3Dhttps%3A%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fapplied-ai-engineering-samples%2Fmain%2Fgenai-on-vertex-ai%2Fgemini%2Fprompting_recipes%2Fspatial_reasoning%2Fspatial_reasoning_app_for_gemini2.ipynb\">\n",
    "      <img src=\"https://www.gstatic.com/images/branding/gcpiconscolors/vertexai/v1/32px.svg\" alt=\"Vertex AI logo\"><br> Open in Vertex AI Workbench\n",
    "    </a>\n",
    "  </td>\n",
    "  <td style=\"text-align: center\">\n",
    "    <a href=\"https://github.com/GoogleCloudPlatform/applied-ai-engineering-samples/blob/main/genai-on-vertex-ai/gemini/prompting_recipes/spatial_reasoning/spatial_reasoning_app_for_gemini2.ipynb\">\n",
    "      <img width=\"32px\" src=\"https://upload.wikimedia.org/wikipedia/commons/9/91/Octicons-mark-github.svg\" alt=\"GitHub logo\"><br> View on GitHub\n",
    "    </a>\n",
    "  </td>\n",
    "</table>\n",
    "\n",
    "<div style=\"clear: both;\"></div>\n",
    "\n",
    "<b>Share to:</b>\n",
    "\n",
    "<a href=\"https://www.linkedin.com/sharing/share-offsite/?url=https%3A//github.com/GoogleCloudPlatform/applied-ai-engineering-samples/blob/main/genai-on-vertex-ai/gemini/prompting_recipes/spatial_reasoning/spatial_reasoning_app_for_gemini2.ipynb\" target=\"_blank\">\n",
    "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/8/81/LinkedIn_icon.svg\" alt=\"LinkedIn logo\">\n",
    "</a>\n",
    "\n",
    "<a href=\"https://bsky.app/intent/compose?text=https%3A//github.com/GoogleCloudPlatform/applied-ai-engineering-samples/blob/main/genai-on-vertex-ai/gemini/prompting_recipes/spatial_reasoning/spatial_reasoning_app_for_gemini2.ipynb\" target=\"_blank\">\n",
    "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/7/7a/Bluesky_Logo.svg\" alt=\"Bluesky logo\">\n",
    "</a>\n",
    "\n",
    "<a href=\"https://twitter.com/intent/tweet?url=https%3A//github.com/GoogleCloudPlatform/applied-ai-engineering-samples/blob/main/genai-on-vertex-ai/gemini/prompting_recipes/spatial_reasoning/spatial_reasoning_app_for_gemini2.ipynb\" target=\"_blank\">\n",
    "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/53/X_logo_2023_original.svg\" alt=\"X logo\">\n",
    "</a>\n",
    "\n",
    "<a href=\"https://reddit.com/submit?url=https%3A//github.com/GoogleCloudPlatform/applied-ai-engineering-samples/blob/main/genai-on-vertex-ai/gemini/prompting_recipes/spatial_reasoning/spatial_reasoning_app_for_gemini2.ipynb\" target=\"_blank\">\n",
    "  <img width=\"20px\" src=\"https://redditinc.com/hubfs/Reddit%20Inc/Brand/Reddit_Logo.png\" alt=\"Reddit logo\">\n",
    "</a>\n",
    "\n",
    "<a href=\"https://www.facebook.com/sharer/sharer.php?u=https%3A//github.com/GoogleCloudPlatform/applied-ai-engineering-samples/blob/main/genai-on-vertex-ai/gemini/prompting_recipes/spatial_reasoning/spatial_reasoning_app_for_gemini2.ipynb\" target=\"_blank\">\n",
    "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/51/Facebook_f_logo_%282019%29.svg\" alt=\"Facebook logo\">\n",
    "</a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "rMo_8ZquBWz-"
   },
   "source": [
    "| | |\n",
    "|-|-|\n",
    "| Author(s) | [Emmanuel Awa](https://github.com/awaemmanuel), [Dennis Kashkin](https://github.com/kashkin)|\n",
    "| Reviewer(s) | [Skander Hannachi](https://github.com/skanderhn), [Rajesh Thallam](https://github.com/rthallam)  |\n",
    "| Last updated | 2024 12 11: Gemini 2.0 Flash Experimental Release  |\n",
    "| | 2024 12 11: Initial Publication |\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "9nPDyXDbwdqH"
   },
   "source": [
    "## Explore Object Detection with Gemini 2.0 in Vertex AI\n",
    "\n",
    "This notebook showcases the power of Gemini 2.0 for object detection and spatial understanding using Vertex AI.\n",
    "\n",
    "Using the embedded app, you'll discover how to leverage Gemini to accurately identify and locate objects in images, similar to the example shown below.  \n",
    "\n",
    "<img src=\"https://storage.mtls.cloud.google.com/gemini_assets/splash.png\"/>\n",
    "\n",
    "Feel free to explore different prompt styles to achieve the desired results. You can start with the pre-defined prompts and image provided, or personalize your experience by uploading your own images and switching to 'Custom Prompt' to craft your own.\n",
    "\n",
    "**IMPORTANT NOTICE:** This notebook only showcases functionality available in model name `gemini-2.0-flash-exp`\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dBps2wcrOYmn"
   },
   "source": [
    "## Environment Setup: GCP and Libraries"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "R5Xep4W9lq-Z"
   },
   "source": [
    "### Install Packages and Restart Runtime (If needed)\n",
    "\n",
    "To use the newly installed packages in this Jupyter runtime, you must restart the runtime. You can do this by running the cell below, which restarts the current kernel. It confirms if the right packages are already installed and restarts the runtime if needed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "-lj_qHHETAKF"
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "import importlib.metadata\n",
    "import time\n",
    "import IPython\n",
    "\n",
    "def check_package_version(package_name, required_version):\n",
    "    try:\n",
    "        installed_version = importlib.metadata.version(package_name)\n",
    "        if installed_version != required_version:\n",
    "            print(f\"Warning: {package_name} {required_version} \"\n",
    "                  f\"required, but {installed_version} is installed.\")\n",
    "            return False  # Indicate version mismatch\n",
    "        return True  # Indicate correct version\n",
    "    except importlib.metadata.PackageNotFoundError:\n",
    "        print(f\"Warning: {package_name} is not installed.\")\n",
    "        return False  # Indicate package not found\n",
    "\n",
    "# List of packages and their required versions\n",
    "packages_to_check = {\n",
    "    'google-cloud-aiplatform': '1.74.0',  # Replace with your desired version\n",
    "    'gradio': '5.8.0',  # Replace with your desired version\n",
    "    # Add more packages and versions as needed\n",
    "}\n",
    "\n",
    "# Check if any required package is missing or has a version mismatch\n",
    "restart_required = False\n",
    "for package_name, required_version in packages_to_check.items():\n",
    "    if not check_package_version(package_name, required_version):\n",
    "        restart_required = True\n",
    "        print(f\"Installing {package_name}=={required_version}\")\n",
    "        !pip install {package_name}=={required_version} --quiet --user\n",
    "\n",
    "# Restart the kernel if necessary\n",
    "if restart_required:\n",
    "    print(\"Restarting kernel...\")\n",
    "    time.sleep(5)  # Add time for the environment to update\n",
    "    app = IPython.Application.instance()\n",
    "    app.kernel.do_shutdown(True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "HTyJBFIS_zlL"
   },
   "source": [
    "Before running the notebook, you will need to provide the following:\n",
    "\n",
    "*   **GCP PROJECT_ID:** Your Google Cloud Project ID.\n",
    "*   **LOCATION:** The region for your Vertex AI resources (e.g., 'us-central1').\n",
    "\n",
    "Make sure you have a Google Cloud Project with billing enabled before proceeding. You can create a new project or use an existing one.\n",
    "\n",
    "You will be prompted to enter these values in a form below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "w5ANs_li78yr"
   },
   "outputs": [],
   "source": [
    "PROJECT_ID = '[your-project-id-here]' # @param {type: 'string'}\n",
    "LOCATION = 'us-central1' # @param {type: 'string'}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "i7xttCL-7u6k"
   },
   "source": [
    "###  Authentication  \n",
    "\n",
    "If you're using Colab, run the code in the next cell. Follow the popups and authenticate with an account that has access to your Google Cloud [project](https://cloud.google.com/resource-manager/docs/creating-managing-projects#identifying_projects).\n",
    "\n",
    "If you're running this notebook somewhere besides Colab, make sure your environment has the right Google Cloud access. If that's a new concept to you, consider looking into [Application Default Credentials for your local environment](https://cloud.google.com/docs/authentication/provide-credentials-adc#local-dev) and [initializing the Google Cloud CLI](https://cloud.google.com/docs/authentication/gcloud). In many cases, running `gcloud auth application-default login` in a shell on the machine running the notebook kernel is sufficient.\n",
    "\n",
    "More authentication options are discussed [here](https://cloud.google.com/docs/authentication)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "CYgFR7ro7yJX"
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "if 'google.colab' in sys.modules:\n",
    "    from google.colab import auth\n",
    "    auth.authenticate_user()\n",
    "    print('Authenticated')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "09a650a2-edd4-4c05-920e-99005de2345e"
   },
   "source": [
    "### Import Libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "RtqcR4pYcD7M"
   },
   "outputs": [],
   "source": [
    "import base64\n",
    "import hashlib\n",
    "import json\n",
    "import os\n",
    "import sys\n",
    "import re\n",
    "import gradio as gr\n",
    "from google.colab import auth, userdata\n",
    "from typing import Optional, Union\n",
    "from PIL import Image as PILImage\n",
    "from PIL import ImageDraw, ImageColor, ImageFont, UnidentifiedImageError\n",
    "\n",
    "import vertexai\n",
    "from vertexai.generative_models import (GenerativeModel,\n",
    "                                        HarmBlockThreshold,\n",
    "                                        HarmCategory,\n",
    "                                        Part)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "b3qAr1TdRrKc"
   },
   "outputs": [],
   "source": [
    "MODEL_NAME = 'gemini-2.0-flash-exp'\n",
    "vertexai.init(project=PROJECT_ID, location=LOCATION)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Gx1eo6ARP5HQ"
   },
   "source": [
    "# Leveraging Gemini for Bounding Boxes\n",
    "\n",
    "This section demonstrates how to utilize the power of Google's Gemini model to identify and extract bounding boxes of objects within images. We'll explore prompt engineering techniques to guide Gemini in accurately detecting desired elements, and then visualize the results by overlaying the bounding boxes onto the original image. This showcases Gemini's capabilities for object detection and its potential applications in various computer vision tasks."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "PX2x2xkZLJpC"
   },
   "source": [
    "### Bounding Box Data Class\n",
    "\n",
    "Define a object that represents a bounding box with 4 coordinates in Gemini format where X and Y are on the 0 to 1000 scale"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "IsEEKkrXLV_M"
   },
   "outputs": [],
   "source": [
    "class BoundingBox():\n",
    "    \"\"\"Create a BoundingBox in Gemini format (X and Y are on 0..1000 scale)\"\"\"\n",
    "\n",
    "    def __init__(self, top: int, left: int, bottom: int, right: int, label: str | None = None):\n",
    "\n",
    "        if None in [top, left, bottom, right]:\n",
    "            raise ValueError(f\"BoundingBox requires all coordinates to be set.\")\n",
    "        if top < 0 or top > 1000: raise ValueError(f'ymin must be an integer between 0 and 1000')\n",
    "        if left < 0: raise ValueError(f'xmin must be an integer between 0 and 1000')\n",
    "        if bottom < 0: raise ValueError(f'ymax must be an integer between 0 and 1000')\n",
    "        if right < 0: raise ValueError(f'xmax must be an integer between 0 and 1000')\n",
    "        if right <= left: raise ValueError(f'xmax must be greater than xmin (right={right}, left={left})')\n",
    "        if bottom <= top: raise ValueError(f'ymax must be greater than ymin (bottom={bottom}, top={top})')\n",
    "        self.left = left\n",
    "        self.right = right\n",
    "        self.top = top\n",
    "        self.bottom = bottom\n",
    "        self.label = label\n",
    "        signature = label or f'{top}-{left}-{bottom}-{right}'\n",
    "        int_hash = int(hashlib.sha256(signature.encode('utf-8')).hexdigest(), 16) % (10 ** 8)\n",
    "        colors = list(ImageColor.colormap.keys())\n",
    "        colors = [color for color in colors if color != 'grey'] # Reserve grey for borders\n",
    "        stable_color_index = int_hash % len(colors)\n",
    "        self.color = colors[stable_color_index]\n",
    "        print(f'Stable color index: {stable_color_index} based on int_hash={int_hash}: {self.color}, {len(colors)} colors')\n",
    "\n",
    "    def __repr__(self):\n",
    "        \"\"\"Return a string representation of the bounding box.\"\"\"\n",
    "        return f'TLBR[{self.top}, {self.left}, {self.bottom}, {self.right}]: {self.label or \"\"} #{self.color}'\n",
    "\n",
    "    @staticmethod\n",
    "    def is_numeric(value: str) -> bool:\n",
    "        \"\"\"Check if a string is a number.\"\"\"\n",
    "        return value.strip().lstrip('-').replace('.', '', 1).isdigit()\n",
    "\n",
    "    @staticmethod\n",
    "    def from_markdown(text: str) -> Union['BoundingBox', None]:\n",
    "        \"\"\"Create a bounding box from a markdown string.\"\"\"\n",
    "        if not text:\n",
    "            return None\n",
    "        for line in text.strip().splitlines():\n",
    "            line = line.strip().lstrip('-').strip()\n",
    "            # Extract the numbers from the line after removing brackets and splitting by comma\n",
    "            if '[' in line and ']' in line:\n",
    "                numbers = line.split('[')[1].split(']')[0].split(',')\n",
    "            else:\n",
    "                numbers =  line.split(',')\n",
    "            if len(numbers) != 4:\n",
    "                print(f'Skipping response line with {len(numbers)} comma separated parts instead of 4: {text}')\n",
    "                continue\n",
    "            ints = [int(num.strip()) for num in numbers if BoundingBox.is_numeric(num)]\n",
    "            if len(ints) != 4:\n",
    "                print(f'Skipping response line with only {len(numbers)} comma separated numbers instead of 4: {text}')\n",
    "                continue\n",
    "            return BoundingBox(ints[0], ints[1], ints[2], ints[3]) # Using the first bounding box (even if the model returns multiple)\n",
    "        return None\n",
    "\n",
    "    @staticmethod\n",
    "    def from_list_of_ints(array, label: str | None = None) -> 'BoundingBox':\n",
    "        \"\"\"Create a bounding box from a list of integers.\"\"\"\n",
    "        if not isinstance(array, list):\n",
    "            raise ValueError(f'Model returned unexpected JSON structure for bounding box coordinates: {json.dumps(array)}')\n",
    "        for coordinate in array:\n",
    "            if not isinstance(coordinate, int):\n",
    "                raise ValueError(f'Model returned unrecognized JSON bounding box coordinate: {coordinate}')\n",
    "        return BoundingBox(top=array[0], left=array[1], bottom=array[2], right=array[3], label=label)\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "BWPkWjrXakqy"
   },
   "source": [
    "### Utilities - Preprocessing and Postprocessing\n",
    "\n",
    "Some helper functions for preprocessing and postprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "8FoG_3N9b-PP"
   },
   "outputs": [],
   "source": [
    "def read_image(local_path_to_image_file: str) -> PILImage.Image:\n",
    "    \"\"\"\n",
    "    Reads an image from a local file path.\n",
    "\n",
    "    Args:\n",
    "        local_path_to_image_file: Path to the local image file.\n",
    "\n",
    "    Returns:\n",
    "        The image as a PIL Image object.\n",
    "    \"\"\"\n",
    "    return PILImage.open(local_path_to_image_file)\n",
    "\n",
    "def encode_image_for_gemini(local_path_to_image_file: str):\n",
    "    \"\"\"\n",
    "    Encodes an image for Gemini.\n",
    "\n",
    "    Args:\n",
    "        local_path_to_image_file: Path to the local image file.\n",
    "\n",
    "    Returns:\n",
    "        The encoded image as a Part object.\n",
    "    \"\"\"\n",
    "    encoded_image = base64.b64encode(open(local_path_to_image_file, 'rb').read()).decode('utf-8')\n",
    "    return Part.from_data(data=base64.b64decode(encoded_image), mime_type='image/jpeg')\n",
    "\n",
    "def strip_json_code_block(text: str) -> str:\n",
    "    \"\"\"Strips the ```json code block markers from a string and returns the JSON.\n",
    "\n",
    "    Args:\n",
    "        text: The input string containing the JSON code block.\n",
    "\n",
    "    Returns:\n",
    "        The extracted JSON string.\n",
    "    \"\"\"\n",
    "    pattern = r\"```json\\s*(.*?)\\s*```\"  # Matches ```json ... ``` with optional whitespace\n",
    "    match = re.search(pattern, text, re.DOTALL)\n",
    "    if match:\n",
    "        return match.group(1).strip()\n",
    "    else:\n",
    "        return text  # Return original text if no code block is found\n",
    "\n",
    "def download_image_from_gcs(gcs_uri: str, local_path: str) -> None:\n",
    "    \"\"\"Downloads an image from Google Cloud Storage (GCS) to a local file.\n",
    "\n",
    "    Args:\n",
    "        gcs_uri: The GCS URI of the image to download (e.g., 'gs://bucket-name/image.jpg').\n",
    "        local_path: The local path where the downloaded image will be saved (e.g., './image.jpg').\n",
    "    \"\"\"\n",
    "    if not os.path.exists(local_path):\n",
    "        print(f'Local path to image file does not exist: {local_path}')\n",
    "        print(f'Downloading sample image from GCS...')\n",
    "        !gsutil cp \"{gcs_uri}\" \"{local_path}\"\n",
    "    else:\n",
    "        print(f'Image already exists at: {local_path}')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "cJX6G-s55981"
   },
   "source": [
    "### Generating Bounding Boxes with Gemini\n",
    "\n",
    "This section focuses on generating bounding boxes using the Gemini model.\n",
    "It involves the following steps:\n",
    "\n",
    "1. **Defining the `generate_bounding_boxes` function:** This function handles the interaction with the Gemini API, sending the image and prompt and receiving the predicted bounding boxes.\n",
    "2. **Setting Generation Parameters:** We'll define parameters like temperature and top_p to control the model's creativity and diversity.\n",
    "3. **Generating Results:** We'll use the defined function and parameters to obtain bounding box predictions from Gemini for a given image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "FFtrCcep5XbC"
   },
   "outputs": [],
   "source": [
    "def generate_bounding_boxes(\n",
    "    model_name: str,\n",
    "    system_instrs: str,\n",
    "    user_prompt: str,\n",
    "    local_path_to_image_file: str,\n",
    "    generation_config: Optional[dict] = None,\n",
    "    safety_settings: Optional[dict] = None\n",
    ") -> list:\n",
    "    \"\"\"\n",
    "    Sends a message to Gemini and returns the model's response for bounding boxes.\n",
    "\n",
    "    Args:\n",
    "        model_name: The name of the generative model to use.\n",
    "        system_in: System-level instructions for the model.\n",
    "        user_message: The user's message to send to the model.\n",
    "        local_path_to_image_file: Path to the local image file.\n",
    "        generation_config: (Optional) Configuration for the model's generation process.\n",
    "        safety_settings: (Optional) Safety settings for the model.\n",
    "\n",
    "    Returns:\n",
    "        A list of bounding boxes, where each box is represented as a dictionary\n",
    "        with keys like 'x', 'y', 'width', 'height', and 'label'.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If the image file path is invalid or the model returns an unexpected response.\n",
    "    \"\"\"\n",
    "    if not os.path.exists(local_path_to_image_file):\n",
    "        raise ValueError(f'Local path to image file does not exist: {local_path_to_image_file}')\n",
    "\n",
    "    generation_config = generation_config or {\n",
    "        'temperature': 0.36,\n",
    "        'top_p': 1.0,\n",
    "        'top_k': 40,\n",
    "        'max_output_tokens': 8192,\n",
    "        'candidate_count': 1,\n",
    "    }\n",
    "    safety_settings = safety_settings or {\n",
    "        HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,\n",
    "        HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,\n",
    "        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_ONLY_HIGH,\n",
    "        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,\n",
    "    }\n",
    "\n",
    "    model = GenerativeModel(\n",
    "        model_name=model_name,\n",
    "        system_instruction=system_instrs,\n",
    "        generation_config=generation_config,\n",
    "        safety_settings=safety_settings,\n",
    "    )\n",
    "    response = model.generate_content(contents=[user_prompt, encode_image_for_gemini(local_path_to_image_file)], stream=False)\n",
    "    try:\n",
    "        bounding_boxes_list = json.loads(strip_json_code_block(response.text))\n",
    "    except json.JSONDecodeError as e:\n",
    "        raise ValueError(f\"Error decoding JSON response from model: {e}\")\n",
    "\n",
    "    if not isinstance(bounding_boxes_list, list):\n",
    "        raise ValueError(f'Model returned unexpected JSON structure instead of an array of objects: {json.dumps(bounding_boxes_list)}')\n",
    "\n",
    "    return bounding_boxes_list"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "4_zvEZNlTEwD"
   },
   "source": [
    "### Utilities - Visualizing Bounding Boxes with Plotting and Rendering\n",
    "\n",
    "This section introduces a set of utility functions designed to visualize the bounding boxes identified by Gemini. These functions handle tasks such as reading images, plotting bounding boxes with distinct colors and labels, and rendering the final output with overlaid boxes onto the source image. These utilities streamline the process of visualizing object detection results and provide a clear representation of Gemini's capabilities."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "HhKWCbazCBGq"
   },
   "outputs": [],
   "source": [
    "def plot_bounding_boxes(\n",
    "        image: PILImage.Image, bounding_boxes: list\n",
    "    ) -> PILImage.Image:\n",
    "        \"\"\"\n",
    "        Overlays bounding boxes on an image with random colors and optional labels.\n",
    "        BoundingBoxes in Gemini format (X and Y are on 0..1000 scale) are auto scaled to the actual image size.\n",
    "\n",
    "        Args:\n",
    "            image: source image\n",
    "            bounding_boxes: a list with one or more bounding boxes\n",
    "\n",
    "        Returns: a new PIL Image object with bounding boxes\n",
    "        \"\"\"\n",
    "        if not image:\n",
    "            raise ValueError(\"image is required\")\n",
    "        if not bounding_boxes:\n",
    "            return image\n",
    "        draw = ImageDraw.Draw(image)\n",
    "        font = ImageFont.load_default(size=15)\n",
    "        for bb in bounding_boxes:\n",
    "            # Convert Gemini coordinates from Gemini scale (0..1000) to absolute pixels\n",
    "            left = int(bb.left / 1000 * image.width)\n",
    "            top = int(bb.top / 1000 * image.height)\n",
    "            right = int(bb.right / 1000 * image.width)\n",
    "            bottom = int(bb.bottom / 1000 * image.height)\n",
    "            if right < left + 8:\n",
    "                right = left + 8 # we cannot fit the border\n",
    "            if bottom < top + 8:\n",
    "                bottom = top + 8\n",
    "            # Border line style: 1 pixel grey + 3 pixels box specific color + 1 pixel grey\n",
    "            draw.rectangle(((left, top), (right, bottom)), outline=ImageColor.colormap['grey'], width=1)\n",
    "            draw.rectangle(((left+1, top+1), (right-1, bottom-1)), outline=bb.color, width=3)\n",
    "            draw.rectangle(((left+4, top+4), (right-4, bottom-4)), outline=ImageColor.colormap['grey'], width=1)\n",
    "            if bb.label:\n",
    "                # Check if the label fits inside of the bounding box\n",
    "                label_left, label_top, label_right, label_bottom = font.getbbox(bb.label)\n",
    "                print(f'label coordinates: label_left={label_left}, label_top={label_top}, label_right={label_right}, label_bottom={label_bottom}')\n",
    "                is_box_wide_enough = (label_right - label_left) + 8 < (right - left)\n",
    "                is_box_tall_enough = (label_bottom - label_top) + 6 < (bottom - top)\n",
    "                print(f'{bb} is_box_wide_enough={is_box_wide_enough}, is_box_tall_enough={is_box_tall_enough}')\n",
    "                if is_box_wide_enough and is_box_tall_enough:\n",
    "                    label_offset_x = 7\n",
    "                    label_offset_y = 3\n",
    "                else: # Print the label below the bounding box\n",
    "                    label_offset_x = 0\n",
    "                    label_offset_y = bottom - top\n",
    "                    print(f'label_offset_y={label_offset_y}')\n",
    "                draw.text((left + label_offset_x, top + label_offset_y), bb.label, fill=ImageColor.colormap['red'], font=font)\n",
    "        return image\n",
    "\n",
    "\n",
    "def render_predicted_bounding_boxes(predictions: list, source_image_path: str, result_image_path: str) -> None:\n",
    "    \"\"\"Renders predicted bounding boxes onto an image and saves the result.\n",
    "\n",
    "    This function takes a list of predictions from a model,\n",
    "    reads the source image, extracts bounding box information\n",
    "    from the predictions, overlays the boxes onto the image,\n",
    "    and saves the resulting image to the specified path.\n",
    "\n",
    "    Args:\n",
    "        predictions: A list of predictions from the model, expected to\n",
    "                     contain bounding box data and optional labels.\n",
    "        source_image_path: The path to the source image file.\n",
    "        result_image_path: The path where the resulting image\n",
    "                           with bounding boxes will be saved.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If the predictions are not in the expected format,\n",
    "                    or if required attributes are missing.\n",
    "\n",
    "    Returns:\n",
    "        None. The function saves the resulting image to the\n",
    "        specified path.\n",
    "    \"\"\"\n",
    "    image = read_image(source_image_path)\n",
    "    boxes: list[BoundingBox] = []\n",
    "    if not isinstance(predictions, list):\n",
    "        raise ValueError(f'Model returned unexpected JSON structure instead of an array of objects: {json.dumps(predictions)}')\n",
    "    for item in predictions:\n",
    "        if not isinstance(item, dict):\n",
    "            raise ValueError(f'Model returned unexpected array item: {json.dumps(item)}')\n",
    "        if 'box_2d' not in item:\n",
    "            raise ValueError(f'Model returned bounding box with missing attribute \"box_2d\": {json.dumps(item)}')\n",
    "        label = item['label'] if 'label' in item else None\n",
    "        box_2d = item['box_2d']\n",
    "        if not isinstance(box_2d, list):\n",
    "            raise ValueError(f'Model returned unexpected box_2d value instead of an array: {json.dumps(box_2d)}')\n",
    "        boxes.append(BoundingBox.from_list_of_ints(box_2d, label=label))\n",
    "    class_labels = {box.label for box in boxes}\n",
    "    if len(class_labels) == 1: # no need to display identical labels\n",
    "        for box in boxes:\n",
    "            box.label = None\n",
    "    result = plot_bounding_boxes(image, boxes)\n",
    "    result.save(result_image_path)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "4euiwfKw56SU"
   },
   "source": [
    "### Prompt Templates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "mqGUFN6ZcLX4"
   },
   "outputs": [],
   "source": [
    "SYSTEM_INSTRUCTIONS = '''\n",
    "Return bounding boxes as a JSON array with labels. Never return masks or code fencing. Limit to 25 objects.\n",
    "If an object is present multiple times, name them according to their unique characteristic (colors, size, position, unique characteristics, etc..).\n",
    "'''\n",
    "\n",
    "PROMPT_SINGLE_OBJECT = 'Could you display the bounding boxes around the Ferris wheel.'\n",
    "\n",
    "PROMPT_SINGLE_CLASS = 'Give me the bounding boxes for all the kites in the park.'\n",
    "\n",
    "PROMPT_MULTIPLE_CLASSES = 'What are the regions defined by the bounding boxes for two types of animals: cats and dogs.'"
   ]
  },
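  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The system instructions ask for a JSON array without code fencing, but a reply may still arrive wrapped in a Markdown fence. The following is a minimal sketch of a defensive parser, not part of this notebook's pipeline. It assumes the raw reply is available as text; `parse_box_response` and the sample reply are hypothetical names and data introduced only for illustration.\n",
    "\n",
    "```python\n",
    "import json\n",
    "\n",
    "FENCE = '`' * 3  # Markdown code-fence marker, built indirectly for readability\n",
    "\n",
    "\n",
    "def parse_box_response(text: str) -> list:\n",
    "    \"\"\"Hypothetical helper: parse a model reply into a list of box dicts.\"\"\"\n",
    "    cleaned = text.strip()\n",
    "    if cleaned.startswith(FENCE):\n",
    "        # Drop the opening fence line (it may carry a language tag such as json).\n",
    "        cleaned = cleaned.split('\\n', 1)[1] if '\\n' in cleaned else ''\n",
    "        cleaned = cleaned.rstrip()\n",
    "        if cleaned.endswith(FENCE):\n",
    "            cleaned = cleaned[:-len(FENCE)]\n",
    "    data = json.loads(cleaned)\n",
    "    if not isinstance(data, list):\n",
    "        raise ValueError('Expected a JSON array of objects')\n",
    "    return data\n",
    "\n",
    "\n",
    "# Made-up reply, for illustration only.\n",
    "sample_reply = FENCE + 'json\\n[{\"box_2d\": [100, 200, 300, 400], \"label\": \"red kite\"}]\\n' + FENCE\n",
    "parsed = parse_box_response(sample_reply)\n",
    "print(parsed[0]['label'], parsed[0]['box_2d'])\n",
    "```\n",
    "\n",
    "Stripping the fence before calling `json.loads` keeps the stricter structural checks (array of objects, `box_2d` present) in one place downstream."
   ]
  },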
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "rRa7ouT_Trjp"
   },
   "source": [
    "### Testing the Bounding Box Generation\n",
    "\n",
    "Now that we've defined the `generate_bounding_boxes` function and set the generation parameters, let's test it with a sample image and prompt. This will help us verify that the model is correctly identifying and returning bounding boxes.\n",
    "\n",
    "**NOTE:** The test on the next cell assumes you have uploaded a sample image to the Colab filesystem and updated `sample_image_path` with the correct file name below. For the purposes of a seamless experiment, we've uploaded a sample image to [GCS](https://storage.mtls.cloud.google.com/gemini_assets/park.jpg).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "a01727OP2xkB"
   },
   "outputs": [],
   "source": [
    "model_name = MODEL_NAME\n",
    "GCS_IMAGE_URI = 'gs://public-aaie-genai-samples/gemini_2_0/spatial_understanding/park.jpg'\n",
    "local_image_path = './park.jpg'\n",
    "download_image_from_gcs(GCS_IMAGE_URI, local_image_path)\n",
    "results = generate_bounding_boxes(model_name,\n",
    "                                  SYSTEM_INSTRUCTIONS,\n",
    "                                  PROMPT_SINGLE_CLASS,\n",
    "                                  local_path_to_image_file=local_image_path)\n",
    "print(results)"
   ]
  },
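  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To read the printed `results`, it helps to map each `box_2d` onto the image. The following is a minimal sketch, assuming the commonly documented Gemini convention that `box_2d` is `[y_min, x_min, y_max, x_max]` normalized to a 0-1000 scale. `to_pixel_box` and the sample values are hypothetical; the notebook's own `BoundingBox` class remains the reference for the actual conversion.\n",
    "\n",
    "```python\n",
    "def to_pixel_box(box_2d, image_width, image_height, scale=1000):\n",
    "    \"\"\"Hypothetical helper: normalized box -> pixel (left, top, right, bottom).\"\"\"\n",
    "    y_min, x_min, y_max, x_max = box_2d  # assumed order: y first, then x\n",
    "    left = round(x_min * image_width / scale)\n",
    "    top = round(y_min * image_height / scale)\n",
    "    right = round(x_max * image_width / scale)\n",
    "    bottom = round(y_max * image_height / scale)\n",
    "    return left, top, right, bottom\n",
    "\n",
    "\n",
    "# Made-up values: one box on a 2000x1000 pixel image.\n",
    "print(to_pixel_box([100, 200, 300, 400], image_width=2000, image_height=1000))\n",
    "# -> (400, 100, 800, 300)\n",
    "```\n",
    "\n",
    "Because the coordinates are normalized, the same prediction can be drawn on a resized copy of the image without asking the model again."
   ]
  },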
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "MV77n6jb1gOF"
   },
   "source": [
    "### Interactive Visualization with Gradio\n",
    "\n",
    "This section integrates bounding box generation into a Gradio application, enabling you to interactively visualize object detection results on uploaded images.\n",
    "\n",
    "**Predefined Prompts:**\n",
    "\n",
    "Start by exploring object detection with our predefined prompts using the provided sample image.  \n",
    "\n",
    "**Custom Prompts:**\n",
    "\n",
    "Switch to \"Custom Prompt\" to unlock the full potential of Gemini.  Experiment with your own prompts, such as precisely locating specific objects within an image and retrieving their bounding box information. For example, you can try prompts like \"Find the red car\" or \"Where are the bicycles?\". Feel free to upload your own images and tailor your prompts for personalized exploration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "yIb2lfS7Kf6L"
   },
   "outputs": [],
   "source": [
    "# @title Download Defaut Image for app\n",
    "GCS_IMAGE_URI = 'gs://public-aaie-genai-samples/gemini_2_0/spatial_understanding/park.jpg'\n",
    "DEFAULT_IMAGE_PATH = './park.jpg'\n",
    "download_image_from_gcs(GCS_IMAGE_URI, DEFAULT_IMAGE_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "xj2sH0HDYoo5"
   },
   "outputs": [],
   "source": [
    "# @title Image Processor\n",
    "def process_image(file_name: str, user_prompt: str = PROMPT_SINGLE_CLASS):\n",
    "    \"\"\"\n",
    "    Takes an input image uploaded from local disk, uploads it to colab and processes it\n",
    "    \"\"\"\n",
    "    try:\n",
    "        current_dir = os.getcwd()\n",
    "        base_name = os.path.basename(file_name)\n",
    "        save_path = os.path.join(current_dir, base_name)\n",
    "        image = PILImage.open(file_name)\n",
    "        image.save(save_path)\n",
    "        message = f'Image saved as {save_path} in the current directory.'\n",
    "        print(message)\n",
    "        try:\n",
    "            results = generate_bounding_boxes(MODEL_NAME, SYSTEM_INSTRUCTIONS, user_prompt=user_prompt, local_path_to_image_file=save_path)\n",
    "        except Exception as e:\n",
    "            error_message = f\"Error generating bounding boxes: {e}\"\n",
    "            raise gr.Error(error_message)\n",
    "        bb_save_path = os.path.join(current_dir, f'{base_name}_bb.jpg')\n",
    "        render_predicted_bounding_boxes(results, save_path, bb_save_path)\n",
    "        return PILImage.open(bb_save_path)\n",
    "    except FileNotFoundError:\n",
    "        raise gr.Error(f\"Error: Image file not found at {file_name}\")\n",
    "    except UnidentifiedImageError:\n",
    "        raise gr.Error(f\"Error: Could not open or read image file {file_name}\")\n",
    "    except Exception as e:  # Catch any other unexpected errors\n",
    "        raise gr.Error(f\"An unexpected error occurred during processing: {e}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "aiJUHmLn3b8c"
   },
   "outputs": [],
   "source": [
    "# @title Main Gradio application\n",
    "with gr.Blocks(title=\"BoxIt With Gemini 2.0\") as demo:\n",
    "    gr.Markdown('# **BoxIt**')\n",
    "\n",
    "    with gr.Row():\n",
    "        image_display = gr.Image(type='filepath', label='Image', value=DEFAULT_IMAGE_PATH)\n",
    "\n",
    "        with gr.Column():\n",
    "            prompt_type = gr.Radio(\n",
    "                choices=['Predefined Prompts', 'Custom Prompt'],\n",
    "                label='Select Prompt Style',\n",
    "            )\n",
    "            predefined_prompts = [PROMPT_SINGLE_OBJECT, PROMPT_SINGLE_CLASS, PROMPT_MULTIPLE_CLASSES]\n",
    "            prompt_dropdown = gr.Dropdown(\n",
    "                choices=predefined_prompts,\n",
    "                label='Choose a predefined prompt:',\n",
    "                visible=False,\n",
    "            )\n",
    "            custom_prompt = gr.Textbox(\n",
    "                lines=2,\n",
    "                label='Enter your prompt',\n",
    "                visible=False,\n",
    "            )\n",
    "            submit_btn = gr.Button('Find Bounding Boxes')\n",
    "            # bounding_box_output = gr.Textbox(label=\"Bounding Boxes\", visible=False)\n",
    "\n",
    "    original_image = gr.State(DEFAULT_IMAGE_PATH)  # Store the *original* uploaded image\n",
    "\n",
    "    def toggle_prompt_input(prompt_type, original_img):\n",
    "        if original_img is not None:\n",
    "            if prompt_type == 'Predefined Prompts':\n",
    "                return gr.update(visible=True), gr.update(visible=False), original_img\n",
    "            elif prompt_type == 'Custom Prompt':\n",
    "                return gr.update(visible=False), gr.update(visible=True), original_img\n",
    "            else:  # \"Select Prompt Style\"\n",
    "                return gr.update(visible=False), gr.update(visible=False), original_img\n",
    "        else:\n",
    "            return gr.update(visible=False), gr.update(visible=False), gr.update()\n",
    "\n",
    "    prompt_type.change(\n",
    "        fn=toggle_prompt_input,\n",
    "        inputs=[prompt_type, original_image],\n",
    "        outputs=[prompt_dropdown, custom_prompt, image_display],\n",
    "    )\n",
    "\n",
    "    def process_and_display(img, prompt_type, selected_prompt, custom_prompt):\n",
    "        if not img:\n",
    "            return gr.update()\n",
    "\n",
    "        if prompt_type == 'Predefined Prompts':\n",
    "            prompt = selected_prompt\n",
    "        elif prompt_type == 'Custom Prompt':\n",
    "            prompt = custom_prompt\n",
    "        else:  # \"Select Prompt Style\" - Do nothing\n",
    "            return img\n",
    "        print(f'Prompt: {prompt}')\n",
    "        print(f'Processing image: {img}')\n",
    "        try:\n",
    "            processed_image = process_image(file_name=img, user_prompt=prompt)\n",
    "            return processed_image\n",
    "        except Exception as e:\n",
    "            print(f'Error processing image: {e}')\n",
    "            return img  # Return the original image in case of an error\n",
    "\n",
    "    image_display.upload(lambda x: x, inputs=image_display, outputs=original_image)\n",
    "\n",
    "    submit_btn.click(\n",
    "        fn=process_and_display,\n",
    "        inputs=[original_image, prompt_type, prompt_dropdown, custom_prompt],\n",
    "        outputs=image_display\n",
    "    )\n",
    "\n",
    "    demo.queue()\n",
    "    demo.launch(quiet=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "pML5Po1qpjI-"
   },
   "source": [
    "# Summary\n",
    "\n",
    "This notebook demonstrates how to use Google's Gemini model to identify and extract bounding boxes of elements within an image. It covers the following key aspects:\n",
    "\n",
    "1. **Environment Setup:** Setting up the Google Cloud Project, installing necessary libraries, and authenticating with Google Cloud.\n",
    "2. **Leveraging Gemini for Bounding Boxes:** Utilizing the Gemini model to generate bounding boxes for specific objects or classes within an image.\n",
    "3. **Prompt Engineering:** Defining prompt templates to guide the Gemini model in accurately detecting the desired elements.\n",
    "4. **Bounding Box Data Class:** Defining a class to represent bounding boxes in the Gemini format, along with methods for parsing and processing them.\n",
    "5. **Utilities:** Helper functions for reading, encoding, and plotting bounding boxes onto images.\n",
    "6. **Visualizing Bounding Boxes:** Displaying the predicted bounding boxes overlaid on the source image for clear visualization.\n",
    "7. **Interactive Interface:** Building a Gradio interface to allow users to upload images, generate bounding boxes, and visualize the results interactively.\n",
    "\n",
    "The notebook showcases the power of Gemini for object detection tasks and provides a practical example of its potential applications in computer vision. It offers a comprehensive guide to using Gemini, from setup to interactive visualization."
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
